{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys, inspect, time\n",
    "\n",
    "import numpy as np\n",
    "import torch \n",
    "import matplotlib.pyplot as plt\n",
    "torch.multiprocessing.set_sharing_strategy('file_system')\n",
    "\n",
    "import discrepancy, visualization\n",
    "from algorithms import ABC_algorithms, TPABC, SMCABC, SMC2ABC, SNLABC, SNL2ABC\n",
    "from problems import problem_IS\n",
    "\n",
    "import utils_os, utils_math\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Problem Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " sampling from true posterior ... \n",
      "\n",
      "x_obs= [[-1. -1. -1. -1.  1. -1. -1.  1.]\n",
      " [-1.  1. -1. -1. -1. -1. -1. -1.]\n",
      " [-1. -1. -1. -1.  1.  1. -1. -1.]\n",
      " [-1. -1. -1. -1. -1. -1.  1.  1.]\n",
      " [ 1.  1.  1. -1. -1. -1. -1. -1.]\n",
      " [ 1.  1.  1. -1. -1. -1. -1.  1.]\n",
      " [-1. -1. -1. -1. -1. -1. -1.  1.]\n",
      " [-1. -1. -1. -1.  1. -1. -1.  1.]]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGMAAABjCAYAAACPO76VAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAABqUlEQVR4nO3dTWrDMBQAYav0CO66voN0/xNYh+i6h3hZFtwUJBLnTZP5IBtjkM0g47/gEhGLGN6yN0A/jAFiDBBjgBgD5H1m5XVdY9u2kzblPnrvw+vWWtPGj4hyXFZmTm1ba7Hv+/iWJSjl1z7+6YzT+tHxr8XwMAViDBBjgBgDxBggxgAxBsjURd+ozHP97EcCI+O31q4ud2aAGAPEGCDGADEGiDFAjAFiDBBjgBgD5JTbIZm3JGZuxYx61P44M0CMAWIMEGOAGAPEGCDGADEGiDFATrkCz5T5MvOtnBkgxgAxBogxQIwBYgwQY4AYA8QYIFNX4L33h12NHmU+V7/32P4l4B8wBogxQIwBYgwQY4AYA8QYIMYAMQbI072QkO2W20XODBBjgBgDxBggxgAxBogxQIwBYgyQqRi11iUiUn6vwJkBYgwQY4AYA8QYIMYAMQaIMUCMAWIMkKd7ISH703C3cGaAGAPEGCDGADEGiDFAjAFiDBBjgEx9nL2U8r0sy9d5m/MyPiPi47hwKobO5WEKxBggxgAxBogxQIwBYgwQY4AYA+QCAU2emmkVh18AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAE/CAYAAADCGpEOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3yV9d3/8dcnG0LCDmGEhB2GCBIR3CICLtxb66poe6v9VVtra8fd2nXr3drWahWrVVt360BxgROVFfaSGVZIQiCQhJCd7++PRG9UxjnJOec64/18PPKQk4RzvS8Cb6/vNb5fc84hIiK+ifM6gIhIJFFpioj4QaUpIuIHlaaIiB9UmiIiflBpioj4IcHrAG3RrVs3l5OT43UMEYkyixYt2uWc636wr0V0aebk5JCfn+91DBGJMma25VBf0/BcRMQPKk0RET+oNEVE/KDSFBHxg0pTRMQPKk0RET+oNEVE/KDSFBHxg0pTRMQPEf1EkEioNDY5dlfVsquyjtJ9tVRU19PQ1ER9Y/PKB6lJCbRPjic9JZFenVLISEshPs48Ti3BoNIU+ZrqukaWbN3Doi17WLdzH+tLKtlUWkVdY5PP7xEfZ/TqlEJuZjrDeqYzrFc6edmd6dohOYjJJRRUmhLznHOs2lHBu6tL+HhdKSsLy2loaj6C7NO5HYN7pHHK4O706dyObh2S6Z6WTHq7RBLj40iMN5yD/XWNVNU1UL6/nh3l1ezYW83m3ftZU1TB7DUlfLEU19Ce6Rw/oCunDclgXP8uJMTrDFmkUWlKzFpbXMlL+dt4a2UxhXuriTMY3bcz007uz7H9ujAmuzPpKYlt3k51XSOri8qZt6mMzzbu4l/ztvD4JwV0bp/I5OGZnHt0L8b370qchvMRwSJ5Ncq8vDynWY7EHzX1jby6pJDnFm5j2ba9JMYbpwzuzqThmZyemxGS4XNNfSMfri3lzRVFvLemhKq6RnK6tufK4/py8ZgsuqQmBT2DHJ6ZLXLO5R30aypNiQV7qup4eu4Wnp67md1VdQzu0YFL87K4YHRvT88z1tQ38vbKYp6Zv4WFm/eQlBDHJWP6cMspA8jq0t6zXLFOpSkxa+/+Ov720Uae/mwL1fWNTMjNYNrJ/TmuXxfMwms4vLa4kic/K+A/iwppdI7zRvXi9gmDyOmW6nW0mKPSlJhTXdfIE58W8MhHG9lX28B5R/fiu6cNZHCPNK+jHVFxeQ2PzdnEs/O3Ut/YxNXjsrn99EEatoeQSlNihnOOd1YV86vXV7OjvIaJQzP4weQh5Gamex3Nbzsra/jT7PU8v2ArqckJ3DZhINef0I9EXXEPOpWmxIQtu6v4xYxVfLi2lNzMNH45dTjH9e/qdaw2W19SyW/fXMMHLfv1uwuPYnTfzl7HimoqTYlqTU2OJz4t4P531pIQZ9wxaQjXjs+Ounsg31lVzC9eW0VJZQ3XjMvmR1NySU3WXYPBcLjS1J+4RLRtZfv5wUvLmF9QxsShPfj1+SPI7JjidaygmDw8kxMGduN/31nLU3M38/G6Uh64bJSOOkMsuv5XLDHl34u2c+af57BqRwX3XzySx741JmoL8wsdkhP476nDee6mcdQ3Oi5+ZC5/nr2eBj8e8ZS2UWlKxKmua+SHLy3jBy8tY3ivdN763klckpcVdrcQBdO4/l1583snce7Injwwex1XPDaPnRU1XseKCSpNiSgFu6q44OFPeWnRdm6fMJBnbxoXszeBd2yXyJ8uH82fLhvFysIKzn7wExYUlHkdK+qpNCVivLemhKkPfkJxRQ3/uP5Y7pg0RNOvAeeP7s2r/3UCHZITuPKxeTz+SQGRfIE33Kk0Jew55/j7nE18++l8sru1543bTuS0IRlexworQzLTeO3WEzgtN4N731jN3f9ZQb3OcwaFSlPCWn1jE/e8upJfz1zD5GGZvHjzePp0js3h+JGkpyTy
6NVjuG3CQF7I38Z1/1hAeXW917GijkpTwlZVbQM3PLmQZ+dv5ZZTBvDwVcfQPkl3yR1OXJxx56Qh3HfxSOZvKuPiv33G9j37vY4VVVSaEpb2VNVx5d/n89nG3dx30UjuPjNX80364dK8LJ6+YSzFFTVc+PBnrCup9DpS1FBpStgpLq/h0kfnsqaogkeuHsOlx2Z5HSkiHT+wG//5zvEAXPboXJZv3+txouig0pSwsnlXFRc/8hlF5TU8df1YzhjWw+tIEW1wjzReumU8qckJXPnYfOZt2u11pIin0pSwUbCriksfnUtVbQPP3nQc4wdE/mQb4SC7ayr/vuV4MjumcO0TC/h4XanXkSJaWJWmmT1hZjvNbKXXWSS0Nu+q4orp82hocjw/bTwj+3TyOlJUyeyYwgvTxtG/ewduejqfzzbs8jpSxAqr0gSeBKZ4HUJCa+vu/Vzx2DxqGxp59qbjGJIZ/hMFR6KuHZL5141jye7anhufyme+huqtElal6Zz7GNBzYDFkW1lzYVbXN/LMt8dF5GTBkaRrh2Se+fY4enVK4fonF7Joi/65+SusSlNiy659tVzz+Hwqa+r5143HMayXCjMUuqcl89xN4+iRnsJ1TyxkZWG515EiSsSVpplNM7N8M8svLdUJ7Ui1r7aB6/+xsOU58rGM6N3R60gxJSM9hWdvOo70dolc94+FbN2tG+B9FXGl6Zyb7pzLc87lde/e3es40gq1DY3c/M98VhdV8PBVxzAmW5PoeqFnx3Y8dcNYGpqauOaJ+ZRW1nodKSJEXGlKZGtsctzx4jI+3dD8pM+EXN2H6aWBGR144rpjKamo4fonF7CvtsHrSGEvrErTzJ4D5gJDzGy7md3odSYJrN+/tYaZy4v4yVm5XDSmj9dxBDimb2cevuoY1hRV8p1/LdLsSEcQVqXpnLvCOdfTOZfonOvjnHvc60wSOM8t2Mpjcwq4dnw2004e4HUcOcCE3B787sKjmLN+F7+YsUrzcR6GpoyRkPh0wy5+9upKThncnZ+dM8zrOHIQl+Zlsam0ikc+2sjA7h244cR+XkcKSypNCbqNpfv4zr8W0b97Kg9eOTrqltaNJndNHsKm0n38euZq+nVP1WTPB6G/vRJUe6rquPHJhSTGx/H4tceSnpLodSQ5jLg444HLRpGbmc5tzy5hbbGmlPs6laYETWOT4/bnl7Bjbw3TvzUmZhdAizSpyQn8/do82iXFc9PT+ZTv1+zvB1JpStA8MGsdc9bv4lfnDWdMdhev44gfenVqxyNXj6GovJrvv7iUpiZdGPqCSlOC4t1Vxfz1gw1cfmwWl4/t63UcaYUx2Z352TnDeP/znfz1gw1exwkbKk0JuIJdVdz54jJG9unIf08d7nUcaYNrxmVzwejePDB7HR+u3el1nLCg0pSA2l/XwC3/XERCvPHwVceQkhjvdSRpAzPjtxccxZAeaXzv+aVsK9Mz6ipNCaifvrqS9TsrefCKY7TUbpRolxTPI1ePock5vvvMYmobGr2O5CmVpgTMy4u38/LiQm4/fRAnDurmdRwJoJxuqfzvJUezorCc+95e63UcT6k0JSAKdlXx01dXMrZfF26bMMjrOBIEk4dn8q3x2Tz+SQEffB675zdVmtJmtQ2N3PbcYpIS4vjz5aOI1/rkUesnZw0lNzONH7y0jJ0VNV7H8YRKU9rsvrfXsrKwgvsuGknPju28jiNBlJIYz1+vHM3+usaYvX9TpSlt8sHnO3n8k+aZiyYNz/Q6joTAwIw0/nvqMD7dsJu/fbTR6zghp9KUViurquOH/15ObmYaPz5rqNdxJIQuzcvinJE9eWDWuphbY0ilKa3inOOnr66gvLqOBy4bpfsxY4yZ8Zvzj6JrhyTueHEpNfWxcxuSSlNaZcayHby5opjvnzGYoT21imQs6tg+kf+5aCTrSvbxx1nrvI4TMipN8VtxeQ0/e3UlY7I7c7NmYI9ppw7J4Krj+vLYnE0sKIiNNdRV
muIX5xx3/Wc59Y2OP1xytG4vEn5y1lCyOrfnzpeWxsTCbCpN8csz87fy8bpSfnL2UHK6pXodR8JAanICf7j0aLbvqeY3M9d4HSfoVJris+179vPbN9dw0qBuXH2cpnuT/3NsThemndSf5xZs5bMNu7yOE1QqTfGJc457XlkJwO8uPAozDcvlq75/xmByurbn7pdXUF0XvVfTVZrik9eW7uCjdaXcNXmIZi+Sg0pJjOd3F45ka9l+HpgdvVfTVZpyRLv31fLL11cxum8nrhmf43UcCWPjB3TlirFZ/H3OJlZsj86b3lWackT3vrGafbUN/M9FI3W1XI7o7jOH0q1DcstdFk1exwk4laYc1gdrd/Lq0h1899SBDO6R5nUciQAd2yVy7/kjWFNUwfSPN3kdJ+BUmnJIVbUN/PSVlQzK6MB3T9NN7OK7ycMzOXNEJn9+bz2bSvd5HSegVJpySH95bz2Fe6v5/UVHkZygZ8vFP788bzjJ8XH8YsYqnIueKeRUmnJQ60oqefyTAi7Ly9Ka5dIqGWkp3DlpMHPW7+LNFcVexwkYlaZ8g3OOn726kg4pCfzozFyv40gEu3pcNsN6pn95MTEaqDTlG15buoP5BWXcNTmXLqlJXseRCJYQH8e954+guKKGv7y33us4AaHSlK+oqKnn1zPXcHRWJy4/NsvrOBIFxmR35vJjs3j8kwLWFld6HafNwqo0zWyKma01sw1mdrfXeWLRH99dx+6qWu49bzhxuidTAuSuKbmkpSTws9dWRvxFobApTTOLBx4CzgSGAVeY2TBvU8WWVTvKeXruZq46ri8j+3TyOo5EkS6pSfxoSi4LCsp4ZUmh13HaJGxKExgLbHDObXLO1QHPA+d5nClmOOf45eur6dQ+iR9O0sUfCbzL8rI4OqsTv3/rc6oi+KJQOJVmb2DbAa+3t3xOQuDtlcUsKCjjzkmD6dg+0es4EoXi4oxfnDuMnZW1/O3DyF3FMpxK82An0L5x8sPMpplZvpnll5aWhiBW9Kupb+S3b60hNzONy/J08UeC55i+nTlvVC+mz9nEtrL9XsdplXAqze3Agf9i+wA7vv5Nzrnpzrk851xe9+7dQxYumj3xaQHbyqr52TnDSIgPp78SEo1+NCWXOIPfv/2511FaJZz+hSwEBplZPzNLAi4HZnicKertrKzhofc3MHFoD04Y2M3rOBIDenVqx80nD2Dm8qKIXIwtbErTOdcA3Aq8A6wBXnTOrfI2VfT7wzvrqGts4p6zh3odRWLILacMoGfHFH71xiqamiLrFqSwKU0A59ybzrnBzrkBzrnfeJ0n2q0sLOfFRdu47vgc+mmRNAmhdknx/GhKLisLK/j34u1ex/FLWJWmhI5zjnvfWE3n9kncOmGQ13EkBp03qhej+3bi/nfWsr8ucm5BUmnGqNlrdjK/oIzvnzGYju10i5GEnpnx07OHUlpZy+NzCryO4zOVZgxqaGzivrc/p3+3VD1fLp4ak92FScN68OjHm9i9r9brOD5RacaglxcXsn7nPn44eQiJusVIPHbXlFyq6xt58P0NXkfxif7FxJia+kb+OGsdo7I6MWVEptdxRBiY0YFL87J4Zv4Wtuyu8jrOEak0Y8yTn22muKKGu8/MxUyzGEl4+H8TBxEfZ9z/zlqvoxyRSjOG7N1fx8MfbGBCbgbj+nf1Oo7Il3qkp/DtE/vzxvIilm/f63Wcw1JpxpC/fbiRytoG7poyxOsoIt9w8yn96ZKaxO/f+jys59xUacaIwr3V/OOzzVw4ug+5melexxH5hrSURG6bMJDPNu7mo3XhOxmPSjNG/GnWOgDumDTY4yQih3bVcdn06dyOP7y7LmyPNlWaMaBgVxUvLynk6uOy6d2pnddxRA4pKSGO208fxIrCct5dXeJ1nINSacaAP89eR1J8HN85dYDXUUSO6MLRvenXLZU/vrsuLCfzUGlGufUllby2bAffOj6b7mnJXscROaKE+Dj+38RBrC2p
ZOaKIq/jfINKM8r96b31tE+M5+aTdZQpkeOckb0Y3KMDD8xeR0Njk9dxvkKlGcXWFFUwc3kR15/Qjy6pSV7HEfFZfJzx/YmD2VRaxWtLv7GAg6dUmlHsgVnrSEtJ4KaT+nsdRcRvk4dnMrxXOn9+bz31YXS0qdKMUiu2N199/PaJ/bW6pESkuDjjjjMGs7VsP/9eFD4TFas0o9QfZ62lY7tEbjgxx+soIq02ITeDUVmd+Mt766ltaPQ6DqDSjEqLt+7hg7WlTDu5P2kpOsqUyGXWfLRZVF7DfxYVeh0HUGlGpQffW0+X1CSuOz7H6ygibXbSoG4cndWJhz/cEBbnNlWaUWZlYTkfrC3lxhP7kZqc4HUckTYzM247bSDb91Tz6hLvjzZVmlHmr+9vIC0lgWvGZ3sdRSRgTh+awbCe6Tz84UYaPX5KSKUZRdaVVPL2qmKuPz6HdJ3LlChiZtw2YSAFu6p4Y7m3922qNKPIQx9soH1SPNef0M/rKCIBN3l4JoMyOvDQBxs8fSZdpRklCnZV8fqyHVwzLpvOevpHolBcnHHrhIGsK9nHO6uKvcvh2ZYloP724QYS4+O48SQdZUr0OmdkL/p1S+XB9zd4Nt+mSjMKbN+zn5cXF3LF2L5kpKV4HUckaOLjjO+eOoDVRRW8//lOTzKoNKPAox9twgymnaxnzCX6nT+6N306t/PsaNPv0jSzVDOLD0YY8V9JRQ0v5G/j4jF96KVZ2SUGJMbHcfMpA1i6bS/zNpWFfPtHLE0zizOzK81sppntBD4HisxslZndb2aDgh9TDuWJTwtoaGzillM0X6bEjkvG9KFbhyQe/XhjyLfty5HmB8AA4MdApnMuyzmXAZwEzAN+b2ZXBzGjHEJFTT3PztvKWUf1JLtrqtdxREImJTGe647P4cO1pawpqgjptn0pzYnOuXudc8udc18++OmcK3PO/cc5dxHwQvAiyqE8N38rlbUNmpVdYtLV47JpnxTP9I83hXS7RyxN51x9IL7ncMzskpbhfpOZ5bXlvWJFXUMTT3xawPEDunJUn45exxEJuU7tk7hibF9mLNvB9j37Q7bdVl09N7N7zewlM3vSzIYEIMdK4ELg4wC8V0x4bWkhJRW13KxzmRLDbjyxHwb8fU5ByLbZ2luOOjnnLgGmAbe3NYRzbo1zbm1b3ydWNDU5pn+8idzMNE4e1M3rOCKe6dWpHVNH9eKFhdvYU1UXkm22tjTrzGw04ABdgQixD9buZP3OfdxyygDMzOs4Ip665ZQBVNc38vTcLSHZXmtL8x5gIjAdHy8CmdlsM1t5kI/z/NmwmU0zs3wzyy8tLW1F9Mj36Eeb6N2pHWeP7Ol1FBHPDe6Rxum5GTw1dzPVdcFfEsPn0jSzP1nLYY1zrsY5d79z7nrn3Fu+/H7n3ETn3IiDfLzmT2Dn3HTnXJ5zLq979+7+/NaosHjrHhZsLuPGE/uRGK8HukQAbj5lAGVVdbyYvy3o2/LnX90+YIaZpQKY2SQz+zQ4seRQpn+0iY7tErns2Cyvo4iEjWNzOnNM3048/klB0Ccp9rk0nXM/BZ4DPjSzT4A7gbsDEcLMLjCz7cB4YKaZvROI9402m3dV8c7qYq4Zl62lLEQOYGbceGJ/tpbtZ/aakqBuy5/h+enATUAV0B243Tk3JxAhnHOvOOf6OOeSnXM9nHOTA/G+0ebJzzaTEGd863gtZSHydZOH96B3p3Y8/klwbz/yZ3h+D/Bz59ypwMXAC2Y2ISip5BvKq+t5MX8b5x7dS9O/iRxEQnwc1x2fw4KCMlYWlgdtO/4Mzyd8cWTpnFsBnAn8OljB5KteXLiN/XWN3KClLEQO6bKxWaQmxQf1aNOXWY76HuwDSARuPOBz6UFLGeMaGpt48rPNHNevCyN665FJkUNJT0nkkrwsXl+2g5KKmqBsw5erCU8d5msOsJb/Pgk8HYBM8jWzVpdQuLean587zOsoImHv+hNyeGruZp6eu5kfTs4N
+PsfsTSdc6cFfKvil8c/KSCrSzsmDu3hdRSRsJfdNZUzhvbgmflbufW0QbRLCuyc6b4Mz3PM7D4ze9nM/m5mt7YMzyUElm3bS/6WPVx3fD/i4/TIpIgvbjyxH3v31/Pyku0Bf29fLgS9BqwFHgLOAI4G5pjZQ2aWHPBE8hX/+LSADskJXJrXx+soIhFjbL8ujOidzhOfFAR8jXRfSjPeOfe4c+49oMw5dxPNM7lvpvnZcwmSkooa3lhexCV5fUhLSfQ6jkjEaL7ZvR8bS6v4aH1g56jwpTRnm9mtLb92AM65Bufc/TQ/wSNB8s+5W2h0juuOz/E6ikjEOfuoXmSkJfPkp5sD+r6+lOYdQEczywd6tcwydLWZPQTsDmga+VJNfSPPzN/CxKE9tP6PSCskJcRx9bhs5m7cTWllbcDe15er503Ab8zsAZqngxsFdKZ5tvV7ApZEvmLG0h3s2V/P9SfkeB1FJGJdOz6Hy4/Nonta4C6/HLE0zcxcs/3AjJaPg35PwFLFOOccT8/bzOAeHRjfv6vXcUQiVsf2iTQ/hxM4Pi3ha2a3ff02IzNLMrMJZvYUcG1AU8W4Jdv2srKwgmvGZWtmdpEw48sTQVOAG4DnzKw/sAdoR3Phvgs84JxbGryIseefc7fQITmBC47RbUYi4caXc5o1wMPAw2aWCHQDqp1ze4MdLhbt3lfLzOVFXD42iw6aM1Mk7Pgzn+aZwBzgQ2C6mY0LVqhY9kL+Nuoam7hmnObMFAlH/syn+TDNs7WPo/mm9v81syuCkipGNTY5npm3lXH9uzCoR5rXcUTkIPwpzRLn3KfOuT3OudnAZHTLUUC9//lOCvdW863xOV5HEZFD8Kc0N5vZr80sqeV1PVAZhEwx6+m5m+mRnswZwzSbkUi48qc0HXAhsK1lYbUNNC+yNigoyWJMwa4q5qzfxZVjs7U0r0gY8/nyrHPuCgAzSwFG0Dzb0dHA382sv3NOa8q2wb/mbSEhzrhirP4YRcKZ3/e0tNyClN/yIQFQXdfIS/nbmDIik4x0LZomEs58Lk0zu+Mgny4HFunm9rZ5fdkOKmoadJuRSATw5+RZHnAL0LvlYxpwKvCYmd0V+Gix47mFWxmY0YGx/bp4HUVEjsCf0uwKHOOcu9M5dyfNJdodOBm4LgjZYsKaogqWbN3LFWP76jlzkQjgT2n2BeoOeF0PZDvnqoHATVYXY55fsJWkhDguHN3b6ygi4gN/LgQ9C8wzs9doXrb3HJon8UgFVgcjXLSrrmvk5SWFnDUik86pSUf+DSLiOX9uObrXzN4ETqS5NG9xzn1xBf2qYISLdjNXFFFZ08AVY7W4p0ik8PeWowagieYb3esDHye2PLdgK/27p+oCkEgE8WeWo+8Bz9A8NVwG8C8zuy1YwaLd2uJKFm3ZwxXH6gKQSCTx50jzRuA451wVgJn9DzAXeDAYwaLdcwu2khQfx0VjNNGwSCTx5+q5AY0HvG5s+Zz4qaa+kZcXb2fyiEy66AKQSETx50jzH8B8M3uF5rI8H3giECHM7H7gXJpvadoIXB/NM8O/uaKIipoGPWcuEoF8PtJ0zv0RuJ7mtc53A9c65x4IUI5ZwAjn3EhgHfDjAL1vWHpuwVZyurbXSpMiEciXJXwrab5a/uWnDviac86ltzWEc+7dA17OAy5u63uGq/UllSzcvIcfn5mrC0AiEciXhdVCve7CDcALId5myDy/cBuJ8aYLQCIRKmTLHZrZbCDzIF+6xzn3Wsv33EPzvaDPHOZ9ptE8WQh9+0bWTeF1DU28sqSQiUN70K1DstdxRKQVQlaazrmJh/u6mV1L86OZpzvn3KG+zzk3neaF3cjLyzvk94Wj9z/fSVlVHZfm6QKQSKQKi4W1zWwK8CPgFOfcfq/zBMtL+dvISEvmpEHdvI4iIq0ULovR/BVIA2aZ2VIze8TrQIG2s7KGD9eVctGYPiRoDSCRiBUWR5rOuYFeZwi2VxYX0tjkuEQXgEQi
mg55QsA5x4v52xiT3Zn+3Tt4HUdE2kClGQJLtu1lY2kVl+bpKFMk0qk0Q+Cl/O20S4zn7JG9vI4iIm2k0gyy6rpGXl+2gzOPyqRDclicQhaRNlBpBtnbq4rYV9ugezNFooRKM8heyt9O3y7tOU6zs4tEBZVmEG0r289nG3dz8Zg+mpxDJEqoNIPo34u2Y4Ym5xCJIirNIGlqcvxn8XZOHNiN3p3aeR1HRAJEpRkkCzaXsX1PNRcdo6NMkWii0gySVxYX0j4pnknDe3gdRUQCSKUZBDX1jby5oogpIzJpn6R7M0WiiUozCGavKaGytkFDc5EopNIMglcWF5KZnsI4LZwmEnVUmgG2a18tH60r5bzRvYiP072ZItFGpRlgry/bQUOT48LRGpqLRCOVZoC9sqSQYT3TGZIZ6kU8RSQUVJoBtGHnPpZvL+fCY3p7HUVEgkSlGUCvLNlOnMHUozVvpki0UmkGSFOT49UlOzhxUHcy0lO8jiMiQaLSDJAFm8so3FvNhaM1NBeJZirNAHllcSGpemxSJOqpNAPg/x6b7KnHJkWinEozAN5bs5PK2gYu0NBcJOqpNAPgtaWFZKQlM36AHpsUiXYqzTYqr67nw7WlnD2ypx6bFIkBKs02emdVMXWNTbo3UyRGqDTb6PVlO+jbpT2jsjp5HUVEQkCl2QallbV8umEXU4/updUmRWKESrMN3lxRRJODqaM0NBeJFSrNNpixbAe5mWkM7qEZjURihUqzlbaV7WfRlj2cqwtAIjElLErTzO41s+VmttTM3jWzsG+i15fvADSjkUisCYvSBO53zo10zo0C3gB+7nWgI5mxdAej+3Yiq0t7r6OISAiFRWk65yoOeJkKOK+y+GJ9SSWfF1dyno4yRWJO2MwuYWa/Ab4FlAOneRznsGYs20GcwdkjVZoisSZkR5pmNtvMVh7k4zwA59w9zrks4Bng1sO8zzQzyzez/NLS0lDF/5JzjhnLdnD8gG50T0sO+fZFxFshO9J0zk308VufBWYCvzjE+0wHpgPk5eWFfBi/fHs5W3bv579OHRjqTYtIGAiLc5pmNuiAl1OBz73KciQzlu0gKT6OySMyvY4iIh4Il3OavzezIUATsAW4xeM8B9XU5Hhj+UgrlfwAAAhOSURBVA5OGdKdju0SvY4jIh4Ii9J0zl3kdQZfLNq6h5KKWs4Z2dPrKCLikbAYnkeKmcuLSEqI4/ShWgdIJFapNH3U1OR4a2URpw3pTofksDhAFxEPqDR99MXQ/KyjNDQXiWUqTR/NXF5EsobmIjFPpemDpibHmyuKOFVDc5GYp9L0Qf6WPeysrNVjkyKi0vTFmytahua5GV5HERGPqTSP4Iuh+WlDMkjV0Fwk5qk0j+CLoflZuqFdRFBpHtHM5Ts0NBeRL6k0D6OxyfHWymINzUXkSyrNw8jfXKahuYh8hUrzMHTVXES+TqV5CI1Njjc1NBeRr1FpHkL+5jJKK2s5W0NzETmASvMQ3lpZTHJCHBM0NBeRA6g0D6KpyfHOqmJOHtxdQ3MR+QqV5kEsLyynqLyGKcO1DpCIfJVK8yDeXllMQpxx+lANzUXkq1SaX+Oc4+2VRYwf0JVO7ZO8jiMiYUal+TXrSvaxefd+JmtoLiIHodL8mrdXFmMGk4ZphnYR+SaV5te8vaqYMX07k5Ge4nUUEQlDKs0DbNldxZqiCqaM0NBcRA5OpXmAd1YVA+h8pogckkrzAG+vLGZ4r3SyurT3OoqIhCmVZouSihoWb92rG9pF5LBUmi3ebRma63ymiByOSrPF26uK6d89lYEZHbyOIiJhTKUJ7KmqY96mMqYMz8TMvI4jImFMpQnMXlNCY5PT0FxEjkilSfOtRr06pnBU745eRxGRMBdWpWlmPzAzZ2bdQrXNfbUNfLx+F5NHaGguIkcWNqVpZlnAGcDWUG73o7Wl1DU06YZ2EfFJ2JQm8ABwF+BCudFZq4vp3D6RvOzOodysiESosChNM5sK
FDrnloVyu/WNTbz/+U4m5PYgIT4s/ihEJMyFbAEcM5sNHGwMfA/wE2CSj+8zDZgG0Ldv3zZlWri5jIqaBs4YphnaRcQ3IStN59zEg33ezI4C+gHLWi7E9AEWm9lY51zxQd5nOjAdIC8vr01D+VmrS0hKiOOkQd3b8jYiEkM8X2rRObcC+PJQz8w2A3nOuV1B3i6zVpdw4sBuWnFSRHwWsyfy1pZUsn1PNWdohnYR8UPYHWI553JCsZ1Zq0oAOD1X5zNFxHcxe6Q5a00Jo7I6aVkLEfFLTJZmcXkNy7eXa2guIn6LydKcvaZ5aK7SFBF/xWRpzlpdQnbX9gzS3Jki4qeYK819tQ3M3bibM4b20AQdIuK3mCvNj9eVUtfYxEQNzUWkFWKuNGetLqGTJugQkVaKqdJs+HKCjgxN0CEirRJTzbFw8x7Kq+uZpKG5iLRSTJWmJugQkbaKmdJ0zjFrTTEnDOiqCTpEpNVipjQ37NzHtrJqzhimZS1EpPVi5pBrYEYHZt5+Ir07tfM6iohEsJgpTTNjeC8t0SsibRMzw3MRkUBQaYqI+EGlKSLiB5WmiIgfVJoiIn5QaYqI+EGlKSLiB5WmiIgfVJoiIn5QaYqI+MGcc15naDUzKwW2eJ0jALoBu7wOESTRum/Rul8Qvfvmz35lO+cOOodkRJdmtDCzfOdcntc5giFa9y1a9wuid98CtV8anouI+EGlKSLiB5VmeJjudYAgitZ9i9b9gujdt4Dsl85pioj4QUeaIiJ+UGmGkJlNMbO1ZrbBzO4+yNdPNrPFZtZgZhd7kbE1fNivO8xstZktN7P3zCzbi5yt4cO+3WJmK8xsqZl9YmbDvMjZGkfatwO+72Izc2YWEVfUffiZXWdmpS0/s6Vm9m2/NuCc00cIPoB4YCPQH0gClgHDvvY9OcBI4GngYq8zB3C/TgPat/z6O8ALXucO4L6lH/DrqcDbXucO1L61fF8a8DEwD8jzOneAfmbXAX9t7TZ0pBk6Y4ENzrlNzrk64HngvAO/wTm32Tm3HGjyImAr+bJfHzjn9re8nAf0CXHG1vJl3yoOeJkKRMpFgiPuW4t7gfuAmlCGawNf96vVVJqh0xvYdsDr7S2fi3T+7teNwFtBTRQ4Pu2bmf2XmW2kuVxuD1G2tjrivpnZaCDLOfdGKIO1ka9/Hy9qOV30bzPL8mcDKs3QsYN8LlKOSg7H5/0ys6uBPOD+oCYKHJ/2zTn3kHNuAPAj4KdBTxUYh903M4sDHgDuDFmiwPDlZ/Y6kOOcGwnMBp7yZwMqzdDZDhz4f7Q+wA6PsgSST/tlZhOBe4CpzrnaEGVrK39/Zs8D5wc1UeAcad/SgBHAh2a2GRgHzIiAi0FH/Jk553Yf8HfwMWCMPxtQaYbOQmCQmfUzsyTgcmCGx5kC4Yj71TLMe5TmwtzpQcbW8mXfBh3w8mxgfQjztcVh9805V+6c6+acy3HO5dB8Lnqqcy7fm7g+8+Vn1vOAl1OBNf5sIKHNEcUnzrkGM7sVeIfmK3xPOOdWmdmvgHzn3AwzOxZ4BegMnGtmv3TODfcw9hH5sl80D8c7AC+ZGcBW59xUz0L7yMd9u7XlKLoe2ANc611i3/m4bxHHx/263cymAg1AGc1X032mJ4JERPyg4bmIiB9UmiIiflBpioj4QaUpIuIHlaaIiB9UmiIiflBpioj4QaUpUc3M4s3sz2a2qmXey/5eZ5LIptKUaPdjYFPLk1V/Ab7rcR6JcHqMUqKWmaUCFzjnvpiQoYDm58NFWk2lKdFsIpBlZktbXneheSowkVbT8Fyi2Sjg5865Uc65UcC7wNIj/B6Rw1JpSjTrDOwHMLMEYBLNE9CKtJpKU6LZOponzwX4PjDTOVfgYR6JApoaTqKWmXWmeT2ibsBcYJpzrtrbVBLpVJoiIn7Q8FxExA8qTRERP6g0RUT8oNIUEfGDSlNExA8qTRERP6g0RUT8oNIUEfHD/wdnZKXh3ld56AAAAABJ
RU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "DIR = 'results/IS'                                             \n",
    "RERUN = not utils_os.is_file_exist(DIR, 'true_samples.npy') \n",
    "\n",
    "## Define the problem\n",
    "problem = problem_IS.IS_Problem(N=500, n=1)\n",
    "true_theta = problem.get_true_theta()\n",
    "\n",
    "## Get x_o ~ p(x|theta)\n",
    "if RERUN:\n",
    "    # observed data x_o\n",
    "    problem.data_obs = problem.simulator(true_theta)\n",
    "    problem.y_obs = problem.statistics(data=problem.data_obs)\n",
    "    utils_os.save_object(DIR, 'data_obs', problem.data_obs)\n",
    "    utils_os.save_object(DIR, 'y_obs', problem.y_obs)\n",
    "else:\n",
    "    problem.data_obs  = utils_os.load_object(DIR, 'data_obs.npy')\n",
    "    problem.y_obs  = problem.statistics(data=problem.data_obs)\n",
    "    \n",
    "\n",
    "## Get True posterior (rejection sampling approximation with 1D sufficient stat)\n",
    "print('\\n sampling from true posterior ... \\n')\n",
    "hyperparams = ABC_algorithms.Hyperparams()\n",
    "hyperparams.save_dir = DIR\n",
    "hyperparams.num_sim = 40000\n",
    "hyperparams.num_samples = 150\n",
    "hyperparams.device = 'cuda:3'\n",
    "hyperparams.L = 1\n",
    "tp_abc = TPABC.TP_ABC(problem, discrepancy=discrepancy.eculidean_dist, hyperparams=hyperparams)\n",
    "if RERUN:\n",
    "    tp_abc.run()\n",
    "    true_samples = tp_abc.rej_samples\n",
    "    utils_os.save_object(DIR, 'true_samples', true_samples)\n",
    "else:\n",
    "    tp_abc = utils_os.load_algorithm(DIR, tp_abc)\n",
    "    true_samples = utils_os.load_object(DIR, 'true_samples.npy')\n",
    "    \n",
    "## Visualize\n",
    "problem.visualize()  \n",
    "visualization.plot_likelihood(samples=true_samples, log_likelihood_function=tp_abc.log_likelihood)\n",
    "plt.savefig('IS_true_posterior.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SMC-ABC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Sequential Monte Carlo ABC\n",
    "\n",
    "hyperparams = ABC_algorithms.Hyperparams()\n",
    "hyperparams.save_dir = DIR\n",
    "hyperparams.device = 'cuda:0'\n",
    "hyperparams.num_sim = 2000                        # number of simulations\n",
    "hyperparams.num_samples = 150                     # number of samples to represent posterior\n",
    "hyperparams.L = 2                                 # number of rounds in sequential learning\n",
    "\n",
    "smc_abc = SMCABC.SMC_ABC(problem, discrepancy=discrepancy.eculidean_dist, hyperparams=hyperparams)\n",
    "smc_abc.run()\n",
    "\n",
    "# Score each stored round against the reference posterior (tp_abc / true_samples).\n",
    "# Iterate over the rounds actually stored in posterior_array, as the other inference\n",
    "# cells do, rather than over hyperparams.L: indexing by L raises an IndexError\n",
    "# whenever the run stored fewer posteriors than requested.\n",
    "JSD_smc_array = []\n",
    "for round_idx in range(len(smc_abc.posterior_array)):\n",
    "    print('round =', round_idx)\n",
    "    # Point the algorithm at this round's posterior before evaluating its likelihood.\n",
    "    smc_abc.posterior = smc_abc.posterior_array[round_idx]\n",
    "    visualization.plot_likelihood(samples=true_samples, log_likelihood_function=smc_abc.log_likelihood, dimensions=(0,1))\n",
    "    JSD = discrepancy.JSD(tp_abc.log_likelihood, smc_abc.log_likelihood, true_samples, true_samples, N_grid=30)\n",
    "    JSD_smc_array.append(JSD)\n",
    "    print('JSD smc = ', JSD)\n",
    "# Persist the per-round scores next to the other results in DIR.\n",
    "utils_os.save_object(DIR, 'JSD_SMC', JSD_smc_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Sequential Monte Carlo ABC +\n",
    "\n",
    "# Configuration for this run; the last three settings describe the statistic S(x).\n",
    "hyperparams = ABC_algorithms.Hyperparams()\n",
    "hyperparams.save_dir = DIR\n",
    "hyperparams.device = 'cuda:1'\n",
    "hyperparams.num_sim = 2000          # simulation budget\n",
    "hyperparams.num_samples = 150       # how many samples represent the posterior\n",
    "hyperparams.L = 2                   # learning rounds\n",
    "hyperparams.type = 'cnn2d'          # architecture of the network behind S(x)\n",
    "hyperparams.stat = 'infomax'        # statistics function, one of infomax/moment/score\n",
    "hyperparams.estimator = 'JSD'       # MI estimator (JSD or DC, see the paper)\n",
    "\n",
    "smc2_abc = SMC2ABC.SMC2_ABC(problem, discrepancy=discrepancy.eculidean_dist, hyperparams=hyperparams)\n",
    "smc2_abc.run()\n",
    "\n",
    "# Evaluate every stored round against the reference posterior and collect the scores.\n",
    "JSD_smc2_array = []\n",
    "num_rounds = len(smc2_abc.posterior_array)\n",
    "for round_idx in range(num_rounds):\n",
    "    print('l=', round_idx)\n",
    "    # Select the round on the algorithm object before querying its likelihood.\n",
    "    smc2_abc.l = round_idx\n",
    "    smc2_abc.posterior = smc2_abc.posterior_array[round_idx]\n",
    "    visualization.plot_likelihood(samples=true_samples, log_likelihood_function=smc2_abc.log_likelihood, dimensions=(0,1))\n",
    "    JSD = discrepancy.JSD(tp_abc.log_likelihood, smc2_abc.log_likelihood, true_samples, true_samples, N_grid=30)\n",
    "    JSD_smc2_array.append(JSD)\n",
    "    print('JSD smc2 = ', JSD)\n",
    "utils_os.save_object(DIR, 'JSD_SMC2', JSD_smc2_array)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SNL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Sequential Neural Likelihood\n",
    "hyperparams = ABC_algorithms.Hyperparams()\n",
    "hyperparams.save_dir = DIR\n",
    "hyperparams.device = 'cuda:1'\n",
    "hyperparams.num_sim = 4000                       # number of simulations\n",
    "hyperparams.L = 2                                # number of learning rounds\n",
    "\n",
    "print('\\n SNL ABC')\n",
    "snl_abc = SNLABC.SNL_ABC(problem, discrepancy=discrepancy.eculidean_dist, hyperparams=hyperparams)\n",
    "snl_abc.run()\n",
    "\n",
    "# Score each stored round against the reference posterior (tp_abc / true_samples).\n",
    "# The list has its own name (not the generic JSD_array shared with the SNL+ cell)\n",
    "# so running either cell never overwrites the other cell's in-kernel results.\n",
    "JSD_snl_array = []\n",
    "for l in range(len(snl_abc.nde_array)):\n",
    "    print('l=', l)\n",
    "    # Activate this round's density estimator before evaluating the likelihood.\n",
    "    snl_abc.nde_net = snl_abc.nde_array[l]\n",
    "    visualization.plot_likelihood(samples=true_samples, log_likelihood_function=snl_abc.log_likelihood, dimensions=(0,1))\n",
    "    JSD = discrepancy.JSD(tp_abc.log_likelihood, snl_abc.log_likelihood, true_samples, true_samples, N_grid=30)\n",
    "    JSD_snl_array.append(JSD)\n",
    "    print('JSD snl = ', JSD)\n",
    "utils_os.save_object(DIR, 'JSD_SNL', JSD_snl_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Sequential Neural Likelihood +\n",
    "hyperparams = ABC_algorithms.Hyperparams()\n",
    "hyperparams.save_dir = DIR\n",
    "hyperparams.device = 'cuda:1'\n",
    "hyperparams.num_sim = 4000                       # number of simulations\n",
    "hyperparams.L = 2                                # number of learning rounds\n",
    "hyperparams.type = 'cnn2d'                       # the network architecture of S(x)\n",
    "hyperparams.stat = 'infomax'                     # statistics function: infomax/moment/score\n",
    "hyperparams.estimator = 'JSD'                    # MI estimator; JSD or DC, see the paper\n",
    "hyperparams.nde = 'MAF'                          # nde; MAF (D>1) or MDN (D=1)\n",
    "\n",
    "snl2_abc = SNL2ABC.SNL2_ABC(problem, discrepancy=discrepancy.eculidean_dist, hyperparams=hyperparams)\n",
    "snl2_abc.run()\n",
    "\n",
    "# Score each stored round against the reference posterior (tp_abc / true_samples).\n",
    "# The list has its own name (not the generic JSD_array shared with the SNL cell)\n",
    "# so running either cell never overwrites the other cell's in-kernel results.\n",
    "JSD_snl2_array = []\n",
    "for l in range(len(snl2_abc.nde_array)):\n",
    "    print('l=', l)\n",
    "    # Select round l on the algorithm object before evaluating its likelihood.\n",
    "    snl2_abc.set(l=l)\n",
    "    visualization.plot_likelihood(samples=true_samples, log_likelihood_function=snl2_abc.log_likelihood, dimensions=(0,1))\n",
    "    JSD = discrepancy.JSD(tp_abc.log_likelihood, snl2_abc.log_likelihood, true_samples, true_samples, N_grid=30)\n",
    "    JSD_snl2_array.append(JSD)\n",
    "    print('JSD snl+ = ', JSD)\n",
    "utils_os.save_object(DIR, 'JSD_SNL2', JSD_snl2_array)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
